"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from functools import reduce

import numpy

import oneflow._oneflow_internal
from oneflow._oneflow_internal.oneflow.core.register import logical_blob_id as lbi_util
from oneflow.compatible import single_client as flow
from oneflow.compatible.single_client.eager import boxing_util as boxing_util
from oneflow.compatible.single_client.framework import (
    balanced_splitter as balanced_splitter,
)
from oneflow.compatible.single_client.framework import dtype as dtype_util
from oneflow.compatible.single_client.framework import id_util as id_util
from oneflow.compatible.single_client.framework import input_blob_def as input_blob_def
from oneflow.compatible.single_client.framework import (
    python_callback as python_callback,
)
from oneflow.compatible.single_client.framework import remote_blob as remote_blob_util
from oneflow.core.operator import op_conf_pb2 as op_conf_util
from oneflow.core.register import logical_blob_id_pb2 as logical_blob_id_util

# Module-wide default blob register shared by the eager-input helpers below.
blob_register = oneflow._oneflow_internal.GetDefaultBlobRegister()


def AsyncPush(session, job_func, *arg):
    """Push each positional argument to the matching input blob def of job_func."""
    input_blob_defs = job_func.__oneflow_input_blob_defs__
    assert len(arg) == len(input_blob_defs)
    for blob_def, ndarray in zip(input_blob_defs, arg):
        _AsyncPushArg(session, blob_def, ndarray)


def _AsyncPushArg(session, arg_blob_def, arg_ndarray):
    """Recursively walk a (possibly nested) blob-def structure and push each leaf.

    Lists/tuples and dicts are traversed element-wise; a leaf must be an
    ArgBlobDef and is pushed via its CheckAndAsyncPush method.
    """
    if isinstance(arg_blob_def, (list, tuple)):
        assert isinstance(arg_ndarray, (list, tuple)), "type(arg_ndarray): %s" % type(
            arg_ndarray
        )
        assert len(arg_blob_def) == len(arg_ndarray), "%s v.s. %s" % (
            len(arg_blob_def),
            len(arg_ndarray),
        )
        for blob_def, ndarray in zip(arg_blob_def, arg_ndarray):
            _AsyncPushArg(session, blob_def, ndarray)
        return
    if isinstance(arg_blob_def, dict):
        assert type(arg_blob_def) is type(arg_ndarray)
        assert set(arg_blob_def.keys()) == set(arg_ndarray.keys())
        for key, blob_def in arg_blob_def.items():
            _AsyncPushArg(session, blob_def, arg_ndarray[key])
        return
    assert isinstance(arg_blob_def, input_blob_def.ArgBlobDef)
    arg_blob_def.CheckAndAsyncPush(session, arg_ndarray)


def MakeEagerInputBlobs(arg_blob_def, arg_ndarray):
    """Mirror a nested blob-def structure, creating one eager input blob per leaf."""
    if isinstance(arg_blob_def, dict):
        assert type(arg_blob_def) is type(arg_ndarray)
        assert set(arg_blob_def.keys()) == set(arg_ndarray.keys())
        return {
            key: MakeEagerInputBlobs(blob_def, arg_ndarray[key])
            for key, blob_def in arg_blob_def.items()
        }
    if isinstance(arg_blob_def, (list, tuple)):
        assert isinstance(arg_ndarray, (list, tuple)), "type(arg_ndarray): %s" % type(
            arg_ndarray
        )
        assert len(arg_blob_def) == len(arg_ndarray)
        children = (
            MakeEagerInputBlobs(blob_def, ndarray)
            for blob_def, ndarray in zip(arg_blob_def, arg_ndarray)
        )
        # Preserve the container type (list vs. tuple) of the input structure.
        return type(arg_blob_def)(children)
    return _CreateEagerInputBlobAndFeedValue(arg_blob_def, arg_ndarray)


def _CheckInputArgBlobDefValueMatch(arg_blob_def, arg_value):
    """Validate that arg_value is structurally compatible with arg_blob_def.

    A FixedTensorDef requires a single ndarray with the exact declared shape.
    A MirroredTensorDef requires a list/tuple of ndarrays, each with the same
    rank as, and no more elements than, the declared shape.
    """
    if isinstance(arg_blob_def, input_blob_def.FixedTensorDef):
        assert isinstance(arg_value, numpy.ndarray)
        assert arg_blob_def.shape == arg_value.shape
    elif isinstance(arg_blob_def, input_blob_def.MirroredTensorDef):
        assert isinstance(arg_value, (list, tuple))
        expected_rank = len(arg_blob_def.shape)
        capacity = numpy.prod(arg_blob_def.shape)
        for piece in arg_value:
            assert isinstance(piece, numpy.ndarray)
            assert len(piece.shape) == expected_rank
            assert numpy.prod(piece.shape) <= capacity
    else:
        raise NotImplementedError


def FeedValueToEagerBlob(blob_object, blob_def, ndarray):
    """Feed ndarray into every physical blob backing the logical blob_object."""
    feed_ctx = FeedContext(blob_object.op_arg_parallel_attr, ndarray)
    for rank, physical in enumerate(_GetPhysicalBlobObjects(blob_object, None)):
        # Each physical blob receives the slice belonging to its rank.
        feed_ctx.set_rank(rank)
        _FeedValueToInputPhysicalBlob(feed_ctx, blob_def, physical)


def _CreateEagerInputBlobAndFeedValue(arg_blob_def, arg_ndarray):
    """Create an eager input blob for arg_blob_def, feed arg_ndarray into it,
    and wrap it in the matching remote-blob type.

    Returns an EagerConsistentBlob (passed through flow.identity under a
    consistent-view scope) for FixedTensorDef, or an EagerMirroredBlob for
    MirroredTensorDef. Raises NotImplementedError for other blob-def types.
    """
    _CheckInputArgBlobDefValueMatch(arg_blob_def, arg_ndarray)
    (arg_blob_object, lbi) = _MakeInputBlobObject(arg_blob_def)
    FeedValueToEagerBlob(arg_blob_object, arg_blob_def, arg_ndarray)
    get_blob = None
    # Normalize a protobuf LogicalBlobId into the internal cfg LogicalBlobId
    # expected by the Eager*Blob constructors.
    if not isinstance(lbi, lbi_util.LogicalBlobId):
        cfg_lbi = lbi_util.LogicalBlobId()
        cfg_lbi.set_op_name(lbi.op_name)
        cfg_lbi.set_blob_name(lbi.blob_name)
        lbi = cfg_lbi
    if isinstance(arg_blob_def, input_blob_def.FixedTensorDef):

        def get_blob(lbi, blob_object, blob_register):
            blob = oneflow._oneflow_internal.EagerConsistentBlob(
                lbi, blob_object, blob_register
            )
            with flow.scope.consistent_view():
                return flow.identity(blob)

    elif isinstance(arg_blob_def, input_blob_def.MirroredTensorDef):
        get_blob = oneflow._oneflow_internal.EagerMirroredBlob
    else:
        raise NotImplementedError
    return get_blob(lbi, blob_object=arg_blob_object, blob_register=blob_register)


def _MakeInputBlobObject(arg_blob_def):
    """Build and run an Input op for arg_blob_def; return its output blob object.

    Returns a tuple of (the "out" blob object produced by the stateless call,
    the logical blob id naming the op's output).
    """
    (input_op_conf, lbi) = _MakeInputOpConfAndRetLbi(arg_blob_def)
    bn_in_op2blob_object = oneflow._oneflow_internal.deprecated.BnInOp2BlobObject()

    def BuildInputInstruction(builder):
        # Infer the op under the current scope, then issue the stateless call
        # that materializes the input blob on the scope's parallel desc.
        op_attribute = arg_blob_def.EagerAddAndInferOp(input_op_conf)
        scope = flow.current_scope()
        parallel_conf = scope.device_parallel_desc_symbol.parallel_conf
        cfg_op_attribute = oneflow._oneflow_internal.deprecated.MakeOpAttributeByString(
            str(op_attribute)
        )
        builder.StatelessCall(
            cfg_op_attribute, parallel_conf, bn_in_op2blob_object, boxing_util.BoxingTo
        )

    oneflow._oneflow_internal.deprecated.LogicalRun(BuildInputInstruction)
    return (bn_in_op2blob_object["out"], lbi)


def _GetPhysicalBlobObjects(logical_blob_object, lbi):
    """Unpack a logical blob object into its per-device physical blob objects.

    Args:
        logical_blob_object: the logical blob object to unpack.
        lbi: unused; kept for backward compatibility with existing callers
            (they pass None).

    Returns:
        The list of physical blob objects produced by the instruction builder.
    """
    # NOTE: a dead local `blob_register` that shadowed the module-level global
    # (and was never read) has been removed.
    physical_blob_objects = None

    def BuildLogical2PhysicalInstruction(builder):
        nonlocal physical_blob_objects
        physical_blob_objects = builder.UnpackLogicalBlobToPhysicalBlobs(
            logical_blob_object
        )

    oneflow._oneflow_internal.deprecated.LogicalRun(BuildLogical2PhysicalInstruction)
    return physical_blob_objects


def _MakeInputOpConfAndRetLbi(arg_blob_def):
    """Build an Input operator conf for arg_blob_def plus the lbi of its output."""
    assert isinstance(arg_blob_def, input_blob_def.ArgBlobDef)
    input_op_conf = op_conf_util.OperatorConf()
    input_op_conf.name = id_util.UniqueStr("Input_")
    input_op_conf.input_conf.out = "out"
    input_op_conf.input_conf.blob_conf.CopyFrom(arg_blob_def.ToInterfaceBlobConf())
    out_lbi = logical_blob_id_util.LogicalBlobId()
    out_lbi.op_name = input_op_conf.name
    out_lbi.blob_name = input_op_conf.input_conf.out
    return (input_op_conf, out_lbi)


class FeedContext(object):
    """Holds the value(s) being fed and the parallel attributes of the input
    op argument, and yields the (contiguous) piece destined for one rank."""

    def __init__(self, op_arg_parallel_attr, arg_ndarray, rank=0):
        self.op_arg_parallel_attr_ = op_arg_parallel_attr
        self.arg_ndarray_ = arg_ndarray
        self.rank_ = rank
        # Lazily computed per-rank (start, end) ranges for split parallelism.
        self.balanced_range_ = None

    def set_rank(self, rank):
        """Select which physical rank subsequent Get* calls serve."""
        self.rank_ = rank

    def GetFixedTensor(self, logical_shape):
        """Return this rank's C-contiguous piece of the fixed (whole) tensor."""
        whole = self.arg_ndarray_
        assert isinstance(whole, numpy.ndarray)
        assert whole.shape == logical_shape, "%s v.s. %s" % (
            whole.shape,
            logical_shape,
        )
        sbp_parallel = self.op_arg_parallel_attr_.sbp_parallel
        parallel_num = self.op_arg_parallel_attr_.parallel_desc_symbol.parallel_num
        if parallel_num == 1 or sbp_parallel.has_broadcast_parallel():
            # Broadcast (or single device): every rank sees the full tensor.
            return self._AsContiguousNdArray(whole)
        if sbp_parallel.has_split_parallel():
            # Slice out this rank's balanced range along the split axis.
            axis = sbp_parallel.split_parallel().axis()
            start, end = self._GetBalancedRanges(logical_shape[axis])[self.rank_]
            index = [slice(None)] * len(logical_shape)
            index[axis] = slice(start, end)
            return self._AsContiguousNdArray(whole[tuple(index)])
        raise NotImplementedError

    def _GetBalancedRanges(self, dim):
        """Compute (once) and return the per-rank balanced ranges for dim."""
        if self.balanced_range_ is None:
            parallel_num = (
                self.op_arg_parallel_attr_.parallel_desc_symbol.parallel_num
            )
            self.balanced_range_ = balanced_splitter.BalancedRanges(dim, parallel_num)
        return self.balanced_range_

    def GetMirroredTensor(self, static_shape):
        """Return this rank's ndarray, checking it fits within static_shape."""
        assert isinstance(self.arg_ndarray_, (list, tuple))
        parallel_num = self.op_arg_parallel_attr_.parallel_desc_symbol.parallel_num
        assert len(self.arg_ndarray_) == parallel_num
        assert all(isinstance(a, numpy.ndarray) for a in self.arg_ndarray_)
        assert self.rank_ >= 0
        assert self.rank_ < parallel_num
        ndarray = self.arg_ndarray_[self.rank_]
        capacity = reduce(lambda x, y: x * y, static_shape, 1)
        elem_cnt = reduce(lambda x, y: x * y, ndarray.shape, 1)
        assert elem_cnt <= capacity, "%s v.s. %s" % (ndarray.shape, static_shape)
        return self._AsContiguousNdArray(ndarray)

    def _AsContiguousNdArray(self, ndarray):
        """Recursively force C-contiguity on an ndarray or nested sequence."""
        if isinstance(ndarray, (tuple, list)):
            return type(ndarray)(self._AsContiguousNdArray(a) for a in ndarray)
        if isinstance(ndarray, numpy.ndarray):
            if ndarray.flags["C_CONTIGUOUS"]:
                return ndarray
            return numpy.ascontiguousarray(ndarray)
        raise NotImplementedError


def _FeedValueToInputPhysicalBlob(feed_ctx, blob_def, blob_object):
    """Register a feed callback for one physical blob and run the feed
    instruction asynchronously.

    The callback copies feed_ctx's value for the current rank into the blob;
    a remove-callback instruction is queued so the registered callback is
    released after it has been consumed.
    """
    assert isinstance(blob_def, input_blob_def.ArgBlobDef)
    assert isinstance(blob_object, oneflow._oneflow_internal.BlobObject)
    FeedBlob = _MakeFeedBlobCallback(feed_ctx, blob_def, blob_object)
    assert callable(FeedBlob)

    def BuildFeedInstruction(builder):
        # NOTE(review): GetIdForRegisteredCallback is called twice with the
        # same callable; this presumably yields the same id both times —
        # confirm against python_callback's registration semantics.
        builder.FeedBlob(
            blob_object, python_callback.GetIdForRegisteredCallback(FeedBlob)
        )
        builder.InsertRemoveForeignCallbackInstruction(
            blob_object.object_id, python_callback.GetIdForRegisteredCallback(FeedBlob)
        )

    oneflow._oneflow_internal.deprecated.PhysicalRun(BuildFeedInstruction)


def _MakeFeedBlobCallback(feed_ctx, blob_def, blob_object):
    """Create the of-blob callback that copies the fed value into the blob.

    The callback flavor depends on the blob def: fixed tensors must match the
    blob's static shape exactly; mirrored tensors only need to fit within it.
    Raises NotImplementedError for other blob-def types.
    """
    if isinstance(blob_def, input_blob_def.FixedTensorDef):

        def FeedBlob(ofblob):
            ndarray = feed_ctx.GetFixedTensor(blob_def.shape)
            dtype = dtype_util.convert_oneflow_dtype_to_numpy_dtype(ofblob.dtype)
            assert ndarray.dtype == dtype, "%s v.s. %s" % (ndarray.dtype, dtype)
            assert ndarray.shape == ofblob.static_shape, "%s v.s. %s" % (
                ndarray.shape,
                ofblob.static_shape,
            )
            if ofblob.CopyFromNdarray(ndarray) is False:
                raise ValueError

        return FeedBlob
    if isinstance(blob_def, input_blob_def.MirroredTensorDef):

        def FeedBlob(ofblob):
            ndarray = feed_ctx.GetMirroredTensor(ofblob.static_shape)
            assert isinstance(ndarray, numpy.ndarray)
            dtype = dtype_util.convert_oneflow_dtype_to_numpy_dtype(ofblob.dtype)
            assert ndarray.dtype == dtype, "%s v.s. %s" % (ndarray.dtype, dtype)
            if ofblob.CopyFromNdarray(ndarray) is False:
                raise ValueError

        return FeedBlob
    raise NotImplementedError
